import torch
from torch._C import set_flush_denormal
import torch.nn as nn
from torch.nn.modules.activation import ReLU
from torch.nn.modules.linear import Linear


class MLP(nn.Module):
    """Fully connected feed-forward network with ReLU between layers.

    The network is a stack of ``nn.Linear`` layers. Every hidden layer is
    followed by ``nn.ReLU``, and the final output layer has no activation.
    With the default arguments the layer widths are ``2 -> 4 -> 9 -> 3 -> 1``,
    identical to the original hard-coded architecture.

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden_sizes: Widths of the hidden layers, in order. An empty
            sequence yields a single ``Linear(in_features, out_features)``.
        out_features: Size of the last dimension of the output tensor.
    """

    def __init__(self, in_features=2, hidden_sizes=(4, 9, 3), out_features=1):
        super().__init__()

        layers = []
        prev_width = in_features
        for width in hidden_sizes:
            layers += [nn.Linear(prev_width, width), nn.ReLU()]
            prev_width = width
        # The output layer has no activation, so outputs are not clamped to
        # be non-negative the way the hidden activations are.
        layers.append(nn.Linear(prev_width, out_features))

        # Keep a single ``network`` Sequential in the original layer order so
        # state_dict keys (``network.0.weight``, ``network.2.weight``, ...) and
        # parameter-initialisation order match the original hard-coded layout.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``.

        ``nn.Linear`` acts on the last dimension, so ``x.shape[-1]`` must equal
        ``in_features``. The result has ``out_features`` as its last dimension.
        """
        return self.network(x)